
import torch
from torch import nn
import copy
import torch.optim as optim
import torch.nn.functional as F
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Raise the summarization threshold to infinity so printed tensors are never
# abbreviated with "...".
torch.set_printoptions(threshold=float('inf'))
class CB(nn.Module):
    """Parameter-free reward built from pairwise distances between 2-D points.

    The module owns no trainable weights.  It reads ``cfg.sk_num``,
    ``cfg.ep_len``, ``cfg.cb.dv`` and ``cfg.d`` from the configuration object;
    ``forward`` uses the first three (presumably: number of skills, episode
    length and the kernel scale -- confirm against the config schema).
    """

    def __init__(self, cfg):
        super(CB, self).__init__()
        self.cfg = cfg
        self.sknum = cfg.sk_num
        self.eplen = cfg.ep_len
        self.dv = cfg.cb.dv
        self.d = cfg.d

    @staticmethod
    def zero_topk_elements(tensor, knn):
        """Zero, in place, all but the ``knn`` smallest entries of every row.

        Args:
            tensor: 2-D tensor; it is modified in place.
            knn: number of (smallest) entries to keep per row.

        Returns:
            The same ``tensor`` object, after modification.
        """
        # topk runs along dim=1, so the number of entries to drop has to be
        # derived from the column count (the row count only coincides with it
        # for square inputs).
        num_cols = tensor.size(1)
        drop = min(max(num_cols - knn, 0), num_cols)
        if drop == 0:
            return tensor

        _, indices = torch.topk(tensor, drop, dim=1)
        tensor.scatter_(1, indices, 0)
        return tensor

    def forward(self, some_ary, idx_set):
        """Return a reward of shape ``(ep_len, sk_num)``.

        Args:
            some_ary: tensor that can be reshaped to
                ``(ep_len + 1, sk_num, 2)``.
            idx_set: unused; kept so all reward modules share one signature.

        Returns:
            For every point, the sum over points of *other* skills of
            ``(1 - exp(-squared_distance / dv)) ** 2``; the first step of each
            skill is dropped from the result.
        """
        steps = self.eplen + 1
        # Regroup so that all points of one skill are contiguous rows.
        points = some_ary.reshape(steps, self.sknum, 2).transpose(0, 1)
        points = points.reshape(-1, 2)

        # Only the first operand carries gradient; the second one is detached.
        diffs = points[None, :, :] - points[:, None, :].detach()
        dist_sq = (diffs ** 2).sum(-1)

        dim_distance = 1 - torch.exp(-dist_sq / self.dv)

        # Pairs that belong to the same skill form the diagonal blocks; they
        # must not contribute to the reward.
        skill_of_row = torch.arange(self.sknum, device=points.device).repeat_interleave(steps)
        same_skill = skill_of_row[:, None] == skill_of_row[None, :]
        dim_distance = dim_distance.masked_fill(same_skill, 0)

        reward = torch.sum(dim_distance ** 2, -1)
        reward = reward.reshape(self.sknum, steps)
        return reward[:, 1:].transpose(0, 1)

    def feature_train(self, ary, idx):
        """No-op: this module has nothing to train."""
        pass

class Large_CB(nn.Module):
    """Parameter-free reward comparing normalized kernel rows between points.

    The module owns no trainable weights.  It reads ``cfg.sk_num``,
    ``cfg.ep_len``, ``cfg.cb.dv`` and ``cfg.d`` from the configuration object
    (presumably: number of skills, episode length and the kernel scale --
    confirm against the config schema).
    """

    def __init__(self, cfg):
        super(Large_CB, self).__init__()
        self.cfg = cfg
        self.sknum = cfg.sk_num
        self.eplen = cfg.ep_len
        self.dv = cfg.cb.dv
        self.d = cfg.d

    @staticmethod
    def zero_topk_elements(tensor, knn):
        """Zero, in place, all but the ``knn`` smallest entries of every row.

        Args:
            tensor: 2-D tensor; it is modified in place.
            knn: number of (smallest) entries to keep per row.

        Returns:
            The same ``tensor`` object, after modification.
        """
        # topk runs along dim=1, so the number of entries to drop has to be
        # derived from the column count, not the row count.
        num_cols = tensor.size(1)
        drop = min(max(num_cols - knn, 0), num_cols)
        if drop == 0:
            return tensor

        _, indices = torch.topk(tensor, drop, dim=1)
        tensor.scatter_(1, indices, 0)
        return tensor

    def forward(self, some_ary, idx_set):
        """Return a per-point reward of shape ``(ep_len, sk_num)``.

        Args:
            some_ary: tensor that can be reshaped to
                ``(ep_len + 1, sk_num, 2)``.
            idx_set: unused; kept so all reward modules share one signature.

        Returns:
            Mean squared difference between the L2-normalized kernel row of a
            point and the rows of points of other skills; the first step of
            each skill is dropped.
        """
        steps = self.eplen + 1
        # Regroup so that all points of one skill are contiguous rows.
        points = some_ary.reshape(steps, self.sknum, 2).transpose(0, 1)
        points = points.reshape(-1, 2)

        diffs = points[None, :, :] - points[:, None, :]
        dist_sq = (diffs ** 2).sum(-1)

        # True for pairs of points that belong to the same skill
        # (the diagonal blocks of the pairwise matrix).
        skill_of_row = torch.arange(self.sknum, device=points.device).repeat_interleave(steps)
        same_skill = skill_of_row[:, None] == skill_of_row[None, :]

        dist_sq = dist_sq.masked_fill(same_skill, 0)

        dist_sq_exp = torch.exp(-dist_sq / self.dv)  # kernel values
        dist_sq_exp = F.normalize(dist_sq_exp, p=2, dim=-1)

        # NOTE: this materializes an (N, N, N) tensor, N = sk_num * (ep_len + 1).
        real_diff = dist_sq_exp.unsqueeze(0) - dist_sq_exp.unsqueeze(1)
        real_diff_sq = (real_diff ** 2).sum(-1)
        real_diff_sq = real_diff_sq.masked_fill(same_skill, 0)

        reward = torch.mean(real_diff_sq, dim=-1)
        reward = reward.reshape(self.sknum, steps)
        reward = reward[:, 1:].transpose(0, 1)
        print("rsea",reward.size())

        return reward

    def _forward(self, some_ary, idx_set):
        """Skill-level variant of ``forward``; returns shape ``(sk_num,)``.

        The pairwise kernel matrix is average-pooled over each
        ``(ep_len + 1) x (ep_len + 1)`` block, giving one value per pair of
        skills, before rows are normalized and compared.

        Args:
            some_ary: tensor that can be reshaped to
                ``(ep_len + 1, sk_num, 2)``.
            idx_set: unused.
        """
        steps = self.eplen + 1
        points = some_ary.reshape(steps, self.sknum, 2).transpose(0, 1)
        points = points.reshape(-1, 2)

        # Compute all differences using broadcasting
        diffs = points[None, :, :] - points[:, None, :]
        dist_sq = (diffs ** 2).sum(-1)
        dist_sq_exp = torch.exp(-dist_sq / self.dv)  # kernel values

        pooled = F.avg_pool2d(dist_sq_exp[None, None], kernel_size=steps, stride=steps)
        # Drop only the two dummy leading axes so the result stays 2-D even
        # when there is a single skill.
        dist_sq_exp = pooled.squeeze(0).squeeze(0)
        torch.diagonal(dist_sq_exp).fill_(1)

        dist_sq_exp = F.normalize(dist_sq_exp, p=2, dim=-1)

        real_diff = dist_sq_exp.unsqueeze(0) - dist_sq_exp.unsqueeze(1)
        real_diff_sq = (real_diff ** 2).sum(-1)

        reward = torch.mean(real_diff_sq, dim=-1)
        print("rsea",reward.size())

        return reward

    def feature_train(self, ary, idx):
        """No-op: this module has nothing to train."""
        pass

class Cic(nn.Module):
    """Contrastive skill discriminator plus a k-nearest-neighbour exploration term.

    Three MLPs are trained jointly with one Adam optimizer:

    * ``embedding`` maps a 2-D coordinate to ``cfg.cic.state_embedding_dim``,
    * ``query`` maps a concatenated (previous, next) embedding pair to
      ``cfg.sk_num`` features,
    * ``key`` maps a skill vector of size ``cfg.index_dim`` to ``cfg.sk_num``
      features.

    NOTE(review): the class name suggests the CIC skill-discovery objective;
    confirm against the project documentation.
    """

    def __init__(self, cfg):
        super(Cic, self).__init__()

        # Construction order is kept (query, key, embedding) so parameter
        # initialization and state_dict layout are unchanged.
        self.query = self._mlp(2*cfg.cic.state_embedding_dim, cfg.hidden_size, cfg.sk_num).to(device)
        self.key = self._mlp(cfg.index_dim, cfg.hidden_size, cfg.sk_num).to(device)
        self.embedding = self._mlp(cfg.state_dim, cfg.hidden_size, cfg.cic.state_embedding_dim).to(device)

        all_parameters = (
            list(self.query.parameters())
            + list(self.key.parameters())
            + list(self.embedding.parameters())
        )
        self.cfg = cfg
        self.optimizer = optim.Adam(all_parameters, lr=cfg.cont_lr, weight_decay=1e-6)
        self.LLU = nn.LeakyReLU(0.1)
        self.state_size = cfg.state_dim
        self.eplen = cfg.ep_len
        self.sknum = cfg.sk_num
        self.d = cfg.d
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _mlp(in_dim, hidden_dim, out_dim):
        """Build the 3-layer LeakyReLU(0.1) MLP shared by all sub-networks."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, out_dim)
        )

    @staticmethod
    def zero_topk_elements(tensor, knn):
        """Zero, in place, all but the ``knn`` smallest entries of every row.

        Args:
            tensor: 2-D tensor; it is modified in place.
            knn: number of (smallest) entries to keep per row.

        Returns:
            The same ``tensor`` object, after modification.
        """
        # topk runs along dim=1, so the number of entries to drop has to be
        # derived from the column count, not the row count.
        num_cols = tensor.size(1)
        drop = min(max(num_cols - knn, 0), num_cols)
        if drop == 0:
            return tensor

        _, indices = torch.topk(tensor, drop, dim=1)
        tensor.scatter_(1, indices, 0)
        return tensor

    def _disc_reward(self, traj_list, skill_set):
        """Contrastive score of every transition against its own skill.

        Args:
            traj_list: tensor indexed as ``[step, skill, feature]`` with
                ``ep_len + 1`` steps; only the first two features are used.
            skill_set: one skill vector per skill, fed to ``self.key``.

        Returns:
            Tensor of shape ``(ep_len, sk_num)``.
        """
        grad_coord = traj_list[:, :, :2].transpose(0, 1)  # skill, step, 2
        pre_state = self.embedding(grad_coord[:, :-1])
        post_state = self.embedding(grad_coord[:, 1:])
        traj_pair = torch.cat((pre_state, post_state), -1)  # skill, ep, 2*emb

        key_outputs = F.normalize(self.key(skill_set), p=2, dim=-1)  # skill, feat
        query_outputs = F.normalize(self.query(traj_pair), p=2, dim=-1)  # skill, ep, feat

        # Similarity of each transition with the key of its own skill.
        # squeeze(-1) removes only the trailing matmul axis, so the result
        # stays (skill, ep) even when ep_len or sk_num is 1.
        positive = torch.bmm(query_outputs, key_outputs.unsqueeze(-1)).squeeze(-1)

        # Similarity of each transition with every key: skill, ep, skill.
        all_keys = key_outputs.transpose(0, 1).unsqueeze(0).expand(self.sknum, -1, -1)
        negative = torch.exp(torch.bmm(query_outputs, all_keys))

        # Log of the mean exponentiated similarity, one value per key.
        negative_mean = torch.log(torch.sum(negative, dim=(0, 1)) / (self.sknum*self.eplen))
        disc_reward = positive - negative_mean.reshape(self.sknum, 1)  # skill, ep
        return disc_reward.transpose(0, 1)  # ep, skill

    def forward(self, traj_list, skill_set):
        """Return the summed (scalar) reward for a batch of trajectories.

        The reward mixes the contrastive score with an exploration bonus
        ``log(1 + mean squared embedding distance to the knn nearest points)``
        using ``cfg.cic.explo_ratio`` as the weight of the bonus.

        Args:
            traj_list: tensor indexed as ``[step, skill, feature]`` with
                ``ep_len + 1`` steps.
            skill_set: one skill vector per skill.
        """
        disc_reward = self._disc_reward(traj_list, skill_set)  # ep, skill

        # Use the same two coordinates the discriminator sees, so extra
        # trailing features cannot change the number of rows.
        ex_traj_list = traj_list[:, :, :2].reshape(-1, 2)
        traj_emb = self.embedding(ex_traj_list)
        diff = traj_emb.unsqueeze(0) - traj_emb.unsqueeze(1)
        diff_sq = torch.sum(torch.square(diff), -1)

        # Keep only the distances to the nearest neighbours of every point.
        diff_sq = self.zero_topk_elements(diff_sq, self.cfg.cic.knn)
        knn_dist = torch.sum((diff_sq/len(ex_traj_list)), -1)
        explo_reward = torch.log(knn_dist + 1)
        explo_reward = explo_reward.reshape(self.eplen+1, self.sknum)

        ratio = self.cfg.cic.explo_ratio
        reward = disc_reward*(1-ratio) + explo_reward[1:]*ratio  # ep, skill

        return torch.sum(reward)

    def feature_train(self, traj_list, skill_set):
        """Run one optimizer step maximizing the contrastive score.

        The trajectories are detached first, so only the three sub-networks
        (and whatever graph ``skill_set`` carries) receive gradients.

        Returns:
            The loss value as a Python float.
        """
        traj_list = traj_list.detach().clone()
        self.train()
        disc_reward = self._disc_reward(traj_list, skill_set)

        loss = -torch.sum(disc_reward)
        print("loss = ", loss)
        self.optimizer.zero_grad()
        # NOTE(review): retain_graph is kept from the original; it is only
        # needed if the caller backpropagates through skill_set again.
        loss.backward(retain_graph = True)

        self.optimizer.step()
        return loss.item()


class Diayn(nn.Module):
    """Skill discriminator trained to recognise the skill behind a state.

    ``disc`` is an MLP mapping a concatenated (state, skill vector) of size
    ``cfg.state_dim + cfg.index_dim`` to ``cfg.sk_num`` scores; it is trained
    with its own Adam optimizer.

    NOTE(review): the class name suggests the DIAYN objective; confirm
    against the project documentation.
    """

    def __init__(self, cfg):
        super(Diayn, self).__init__()
        self.cfg = cfg

        self.disc = nn.Sequential(
            nn.Linear(cfg.state_dim + cfg.index_dim, cfg.hidden_size),
            nn.LeakyReLU(0.1),
            nn.Linear(cfg.hidden_size, cfg.hidden_size),
            nn.LeakyReLU(0.1),
            nn.Linear(cfg.hidden_size, cfg.sk_num)
        ).to(device)

        self.optimizer = optim.Adam(self.parameters(), lr=cfg.cont_lr, weight_decay=1e-6)

        self.state_size = cfg.state_dim
        self.eplen = cfg.ep_len
        self.sknum = cfg.sk_num
        self.d = cfg.d
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, grad_traj, skill_set):
        """Return the summed (scalar) discriminator reward.

        Args:
            grad_traj: tensor indexed as ``[step, skill, state feature]`` with
                ``ep_len + 1`` steps; the first step is skipped.
            skill_set: one skill vector per skill, shape
                ``(sk_num, index_dim)``.

        Returns:
            Sum over steps and skills of the (twice soft-maxed) score that
            the discriminator assigns to the matching skill.
        """
        post_state = grad_traj[1:]  # ep, skill, state
        # expand() broadcasts the skill vectors over the steps without copying.
        skills = skill_set.unsqueeze(0).expand(self.eplen, -1, -1)
        net_input = torch.cat((post_state, skills), dim=-1)

        output = self.disc(net_input)  # ep, skill, skill
        # First softmax is taken jointly over all (skill, skill) scores of a step.
        output = self.softmax(output.reshape(self.eplen, -1))
        output = output.reshape(self.eplen, self.sknum, self.sknum)
        # Then every row is L2-normalized and soft-maxed again.
        output = F.normalize(output, p=2, dim=-1)
        output = self.softmax(output)
        reward = torch.diagonal(output, dim1=-1, dim2=-2)

        return torch.sum(reward)

    @staticmethod
    def zero_topk_elements(tensor, knn):
        """Zero, in place, all but the ``knn`` smallest entries of every row.

        Args:
            tensor: 2-D tensor; it is modified in place.
            knn: number of (smallest) entries to keep per row.

        Returns:
            The same ``tensor`` object, after modification.
        """
        # topk runs along dim=1, so the number of entries to drop has to be
        # derived from the column count, not the row count.
        num_cols = tensor.size(1)
        drop = min(max(num_cols - knn, 0), num_cols)
        if drop == 0:
            return tensor

        _, indices = torch.topk(tensor, drop, dim=1)
        tensor.scatter_(1, indices, 0)
        return tensor

    def feature_train(self, grad_traj, skill_set):
        """Run one optimizer step maximizing ``forward`` on detached states.

        Returns:
            The loss value (negated reward) as a Python float.
        """
        # Detach so only the discriminator (and whatever graph skill_set
        # carries) receives gradients.
        traj = grad_traj.detach().clone()
        self.train()
        loss = -self(traj, skill_set)
        print("loss = ", loss)
        self.optimizer.zero_grad()
        # NOTE(review): retain_graph is kept from the original; it is only
        # needed if the caller backpropagates through skill_set again.
        loss.backward(retain_graph = True)

        self.optimizer.step()
        return loss.item()


class Dads(nn.Module):
    """Skill-conditioned dynamics model used to score state transitions.

    ``dads`` is an MLP mapping a concatenated (state, skill vector) of size
    ``cfg.state_dim + cfg.index_dim`` to ``cfg.dads.dads_expert`` predicted
    state differences of size ``cfg.state_dim``.  A prediction is scored with
    the kernel ``exp(-squared_error / cfg.dads.dads_var)`` averaged over the
    experts.

    NOTE(review): the class name suggests the DADS objective; confirm
    against the project documentation.
    """

    def __init__(self, cfg):
        super(Dads, self).__init__()
        self.cfg = cfg
        self.dads = nn.Sequential(
            nn.Linear(cfg.state_dim + cfg.index_dim, cfg.hidden_size),
            nn.LeakyReLU(0.1),
            nn.Linear(cfg.hidden_size, cfg.hidden_size),
            nn.LeakyReLU(0.1),
            nn.Linear(cfg.hidden_size, cfg.dads.dads_expert * cfg.state_dim)
        ).to(device)

        self.optimizer = optim.Adam(self.parameters(), lr=cfg.cont_lr)

        self.state_size = cfg.state_dim
        self.eplen = cfg.ep_len
        self.sknum = cfg.sk_num
        self.var = cfg.dads.dads_var
        self.d = cfg.d
        self.expert = cfg.dads.dads_expert

    def forward(self, grad_traj, skill_set):
        """Return the summed (scalar) reward ``log p(own skill) - log mean_z p(z)``.

        Args:
            grad_traj: tensor indexed as ``[step, skill, state feature]`` with
                ``ep_len + 1`` steps.
            skill_set: one skill vector per skill, shape
                ``(sk_num, index_dim)``.
        """
        # The current state is detached; gradient only flows via the next state.
        input_state = grad_traj[:-1].detach().clone()  # ep, skill, state
        next_state = grad_traj[1:]  # ep, skill, state

        total_prob = self.base(input_state, next_state, skill_set)
        base = torch.log(total_prob+0.0001)

        prob = self.positive(input_state, next_state, skill_set)
        forward = torch.log(prob+0.0001)

        reward = forward - base
        return torch.sum(reward)

    def _kernel_prob(self, pred_state_diff, state_diff):
        """Score predictions with ``exp(-squared error / var)``, averaged over experts.

        ``pred_state_diff`` has the experts on its second-to-last axis and the
        state features on its last axis; ``state_diff`` must broadcast to it.
        """
        sq_error = torch.sum((pred_state_diff - state_diff)**2, -1)
        return torch.mean(torch.exp(-sq_error/self.var), dim=-1)

    def positive(self, input_state, next_state, skill_set):
        """Score every transition under the skill that generated it.

        Args:
            input_state, next_state: tensors of shape
                ``(ep_len, sk_num, state_dim)``.
            skill_set: tensor of shape ``(sk_num, index_dim)``.

        Returns:
            Tensor of shape ``(ep_len * sk_num,)``.
        """
        # expand() broadcasts the skill vectors over the steps without copying.
        skills = skill_set.unsqueeze(0).expand(self.eplen, -1, -1)
        nn_input = torch.cat((input_state, skills), dim = -1)

        # Predicted state difference of every expert.
        output = self.dads(nn_input)
        output = output.reshape(self.eplen * self.sknum, self.expert, self.state_size)

        state_diff = next_state - input_state
        state_diff = state_diff.reshape(self.eplen * self.sknum, 1, self.state_size)

        return self._kernel_prob(output, state_diff)

    def base(self, input_state, next_state, skill_set):
        """Score every transition under *every* skill and average over skills.

        Args:
            input_state, next_state: tensors of shape
                ``(ep_len, sk_num, state_dim)``.
            skill_set: tensor of shape ``(sk_num, index_dim)``.

        Returns:
            Tensor of shape ``(ep_len * sk_num,)``.
        """
        # Pair each state (axis 1) with each candidate skill (axis 2):
        # ep, skill(state owner), skill(candidate), feature.
        skills = skill_set.unsqueeze(0).unsqueeze(0).expand(self.eplen, self.sknum, -1, -1)
        states = input_state.unsqueeze(2).expand(-1, -1, self.sknum, -1)
        nn_input = torch.cat((states, skills), dim = -1)

        output = self.dads(nn_input)
        output = output.reshape(self.eplen * self.sknum, self.sknum, self.expert, self.state_size)

        # The observed difference does not depend on the candidate skill, so
        # it is broadcast over the candidate and expert axes.
        state_diff = next_state - input_state
        state_diff = state_diff.reshape(self.eplen * self.sknum, 1, 1, self.state_size)

        prob = self._kernel_prob(output, state_diff)
        print("p2size", prob.size())
        # Average over the candidate skills.
        prob = torch.mean(prob, dim=-1)

        return prob

    def feature_train(self, grad_traj, skill_set):
        """Run one optimizer step maximizing ``positive`` on detached states.

        Returns:
            The loss value as a Python float.
        """
        # Switch to training mode before the forward pass.
        self.train()
        input_state = grad_traj[:-1].detach().clone()  # ep, skill, state
        next_state = grad_traj[1:].detach().clone()  # ep, skill, state
        prob = self.positive(input_state, next_state, skill_set)
        loss = -torch.sum(prob)
        print("loss = ", loss)
        self.optimizer.zero_grad()
        # NOTE(review): retain_graph is kept from the original; it is only
        # needed if the caller backpropagates through skill_set again.
        loss.backward(retain_graph = True)
        self.optimizer.step()
        return loss.item()

    @staticmethod
    def zero_topk_elements(tensor, knn):
        """Zero, in place, all but the ``knn`` smallest entries of every row.

        Args:
            tensor: 2-D tensor; it is modified in place.
            knn: number of (smallest) entries to keep per row.

        Returns:
            The same ``tensor`` object, after modification.
        """
        # topk runs along dim=1, so the number of entries to drop has to be
        # derived from the column count, not the row count.
        num_cols = tensor.size(1)
        drop = min(max(num_cols - knn, 0), num_cols)
        if drop == 0:
            return tensor

        _, indices = torch.topk(tensor, drop, dim=1)
        tensor.scatter_(1, indices, 0)
        return tensor